{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/llm/nvidia_tensorrt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nvidia TensorRT-LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.\n",
    "\n",
    "[TensorRT-LLM Github](https://github.com/NVIDIA/TensorRT-LLM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorRT-LLM Environment Setup\n",
    "Since TensorRT-LLM is a SDK for interacting with local models in process there are a few environment steps that must be followed to ensure that the TensorRT-LLM setup can be used.\n",
    "\n",
    "1. Nvidia Cuda 12.2 or higher is currently required to run TensorRT-LLM\n",
    "2. Install `tensorrt_llm` via pip with `pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com`\n",
    "3. For this example we will use Llama2. The Llama2 model files need to be created via scripts following the instructions [here](https://github.com/NVIDIA/trt-llm-rag-windows/blob/release/1.0/README.md#building-trt-engine)\n",
    "    * The following files will be created from following the stop above\n",
    "    * `Llama_float16_tp1_rank0.engine`: The main output of the build script, containing the executable graph of operations with the model weights embedded.\n",
    "    * `config.json`: Includes detailed information about the model, like its general structure and precision, as well as information about which plug-ins were incorporated into the engine.\n",
    "    * `model.cache`: Caches some of the timing and optimization information from model compilation, making successive builds quicker.\n",
    "4. `mkdir model`\n",
    "5. Move all of the files mentioned above to the model directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com"
   ]
  },
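  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps 4 and 5 above amount to placing the three build artifacts into a local `model` directory. Below is a minimal shell sketch; it assumes the engine build wrote its outputs to `./trt_build`, a placeholder path you should replace with wherever your build script actually ran."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the model directory and move the build artifacts into it.\n",
    "# `./trt_build` is a placeholder for wherever your engine build wrote its outputs.\n",
    "!mkdir -p model\n",
    "!mv ./trt_build/llama_float16_tp1_rank0.engine ./trt_build/config.json ./trt_build/model.cache ./model/"
   ]
  },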
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Call `complete` with a prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.llms import LocalTensorRTLLM\n",
    "\n",
    "\n",
    "def completion_to_prompt(completion: str) -> str:\n",
    "    \"\"\"\n",
    "    Given a completion, return the prompt using llama2 format.\n",
    "    \"\"\"\n",
    "    return f\"<s> [INST] {completion} [/INST] \"\n",
    "\n",
    "\n",
    "llm = LocalTensorRTLLM(\n",
    "    model_path=\"./model\",\n",
    "    engine_name=\"llama_float16_tp1_rank0.engine\",\n",
    "    tokenizer_dir=\"meta-llama/Llama-2-13b-chat\",\n",
    "    completion_to_prompt=completion_to_prompt,\n",
    ")"
   ]
  },
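  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can print the exact string `completion_to_prompt` will hand to the engine, since a malformed prompt template is a common cause of poor Llama 2 chat output. This is plain Python and runs without the engine loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the formatted prompt for a sample completion request.\n",
    "print(repr(completion_to_prompt(\"Who is Paul Graham?\")))"
   ]
  },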
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resp = llm.complete(\"Who is Paul Graham?\")\n",
    "print(str(resp))"
   ]
  }
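,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generation behavior can also be tuned when constructing the LLM. The sketch below assumes the integration accepts `temperature` and `max_new_tokens` keyword arguments; check the `LocalTensorRTLLM` signature in your installed version before relying on them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the LLM with explicit sampling settings.\n",
    "# `temperature` and `max_new_tokens` are assumed keyword arguments; verify\n",
    "# they exist on LocalTensorRTLLM in your installed version.\n",
    "llm = LocalTensorRTLLM(\n",
    "    model_path=\"./model\",\n",
    "    engine_name=\"llama_float16_tp1_rank0.engine\",\n",
    "    tokenizer_dir=\"meta-llama/Llama-2-13b-chat\",\n",
    "    completion_to_prompt=completion_to_prompt,\n",
    "    temperature=0.1,\n",
    "    max_new_tokens=256,\n",
    ")\n",
    "\n",
    "resp = llm.complete(\"What did Paul Graham work on?\")\n",
    "print(str(resp))"
   ]
  }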
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  },
  "vscode": {
   "interpreter": {
    "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
